#!/usr/bin/python
# -*- coding: utf-8 -*-

# TODO:
# - output cseg

from pylab import *
from collections import Counter,defaultdict
import glob,re,heapq,os,cPickle
import ocrolib
from ocrolib import ngraphs as ng

def method(cls):
    """Decorator factory that attaches the decorated function to ``cls`` as a method.

    The function is stored on the class under its own name.  The decorator
    itself returns None, so the module-level name of the decorated function
    ends up bound to None rather than to the function.
    """
    def _wrap(f):
        # setattr works for both old- and new-style classes (a new-style
        # class's __dict__ is a read-only proxy) and, unlike the deprecated
        # ``new`` module, also exists on Python 3
        setattr(cls,f.__name__,f)
        return None
    return _wrap

class Edge:
    """A single lattice transition; every keyword argument becomes an attribute.

    Lattice.addEdge builds edges with start, stop, cost, cls and seg.
    """
    def __init__(self,**kw):
        for name,value in kw.items():
            setattr(self,name,value)
@method(Edge)
def __str__(self):
    """Return exactly the text produced by __repr__."""
    text = self.__repr__()
    return text
@method(Edge)
def __repr__(self):
    """Render as ``<start:stop [cls] cost seg0:seg1>`` with cost to two decimals."""
    fields = (self.start,self.stop,self.cls,self.cost,self.seg[0],self.seg[1])
    return "<%d:%d [%s] %.2f %d:%d>"%fields

class Lattice:
    """A graph of costed OCR alternatives over integer states.

    Keyword arguments override any of the defaults set below.
    """
    def __init__(self,**kw):
        # cap applied to whitespace / no-whitespace costs by readLattice
        self.maxws = 20.0
        # cap applied to character costs by readLattice
        self.maxcost = 20.0
        # cost of the skip / replace / arbitrary-insertion edges
        self.mismatch = 30.0
        # accepting states; isAccept() defaults this to [lastState()]
        self.accept = None
        # start with an empty graph so addEdge() and friends work even
        # before (or without) readLattice()
        self.states = set()
        self.edges = defaultdict(list)
        self.__dict__.update(kw)

@method(Lattice)
def addEdge(self,start=None,stop=None,cost=None,cls=None,seg=(0,0)):
    """Add an edge from state ``start`` to ``stop`` labelled ``cls`` with ``cost``.

    Both endpoints are recorded in self.states and the new Edge is appended
    to the outgoing list of ``start``; ``seg`` is stored on the edge as-is.
    """
    self.states.update((start,stop))
    outgoing = self.edges[start]
    outgoing.append(Edge(start=start,stop=stop,cost=cost,cls=cls,seg=seg))

@method(Lattice)
def readLattice(self,fname):
    """Load a lattice from a text file, replacing any existing states and edges.

    Each line is split on whitespace and dispatched on its first field:

    - ``segment``: field 2 is ``first:last`` (an OCR segment range) and
      fields 4 and 5 are the whitespace / no-whitespace costs, each clamped
      to ``self.maxws``.  Unless the segment's intermediate state already
      has outgoing edges, skip/replace ("" / "~") edges into it and
      space / no-space / arbitrary ("~") edges out of it are added.
    - ``chr``: field 3 is a character cost and field 4 its class; it is a
      candidate for the most recent ``segment`` line.

    Blank lines and lines with any other first field are ignored.

    Returns self.
    Raises ValueError if a ``chr`` line appears before any ``segment`` line.
    """
    self.states = set()
    self.edges = defaultdict(list)
    st_start = None
    with open(fname) as stream:
        for lineno,line in enumerate(stream,1):
            f = line.split()
            if not f:
                continue
            if f[0]=="segment":
                first,last = [int(x) for x in f[2].split(":")]
                # OCR segment k maps to lattice state 2*k; the odd state
                # 2*last+1 after a segment is an intermediate state used to
                # insert spaces and extra characters before the next segment
                st_start = 2*first
                st_extra = 2*last+1
                st_next = 2*last+2
                ws,nows = [float(x) for x in f[4:6]]
                ws = minimum(ws,self.maxws)
                nows = minimum(nows,self.maxws)
                if not self.edges[st_extra]:
                    # skip or replace
                    self.addEdge(start=st_start,stop=st_extra,cost=self.mismatch,cls="")
                    self.addEdge(start=st_start,stop=st_extra,cost=self.mismatch,cls="~")
                    # insert space / no space
                    self.addEdge(start=st_extra,stop=st_next,cost=ws,cls=" ")
                    self.addEdge(start=st_extra,stop=st_next,cost=nows,cls="")
                    # insert arbitrary (this implies "no space")
                    self.addEdge(start=st_extra,stop=st_next,cost=self.mismatch,cls="~")
            elif f[0]=="chr":
                if st_start is None:
                    raise ValueError("%s:%d: 'chr' line before any 'segment' line"%(fname,lineno))
                # the no-space cost is folded into the character cost and the
                # sum is capped at maxcost (previously nows was added a second
                # time after the cap, defeating it)
                # NOTE(review): nows is also on the st_extra->st_next edge;
                # confirm folding it in here is intended at all
                cost = minimum(float(f[3])+nows,self.maxcost)
                self.addEdge(start=st_start,stop=st_extra,cost=cost,cls=f[4],seg=(first,last))
    return self

@method(Lattice)
def isAccept(self,i):
    """Return True if state ``i`` is accepting.

    When no accepting states were configured, the last state becomes the
    single accepting state (and is remembered in self.accept).
    """
    if self.accept is not None:
        return i in self.accept
    self.accept = [self.lastState()]
    return i in self.accept

@method(Lattice)
def latticeGraph(self):
    """Build a pydot left-to-right digraph with one node per state.

    Each edge is labelled ``"cls/cost"``.  Quote, empty and space classes
    are shown as ``<QUOTE>``, ``<EPS>`` and ``<SPC>`` so that the DOT label
    stays valid and visible.
    """
    import pydot
    display = {"'":"<QUOTE>",'"':"<QUOTE>","":"<EPS>"," ":"<SPC>"}
    graph = pydot.Dot("lattice",graph_type="digraph",rankdir="LR")
    states = sorted(self.states)
    for s in states:
        graph.add_node(pydot.Node(str(s)))
    for s in states:
        # .get avoids inserting empty entries into the edges defaultdict
        for edge in self.edges.get(s,[]):
            cls = display.get(edge.cls,edge.cls)
            graph.add_edge(pydot.Edge(src=str(edge.start),dst=str(edge.stop),label='"%s/%s"'%(cls,edge.cost)))
    return graph

@method(Lattice)
def showLattice(self):
    """Render this lattice to temp.png and open it with eog in the background."""
    # use self, not the module-level ``lattice`` global
    graph = self.latticeGraph()
    # create_png() returns binary image data, so write in binary mode
    with open("temp.png","wb") as stream:
        stream.write(graph.create_png())
    os.system("eog temp.png&")
    
@method(Lattice)
def startState(self):
    """Return the lowest-numbered state of the lattice."""
    first = min(self.states)
    return first

@method(Lattice)
def lastState(self):
    """Return the highest-numbered state of the lattice."""
    final = max(self.states)
    return final

@method(Lattice)
def classes(self):
    """Return the sorted list of distinct class labels over all edges.

    Returns an empty list for a lattice without edges.  The previous
    reduce-based version raised TypeError in that case and concatenated
    lists quadratically.
    """
    labels = set(e.cls for edges in self.edges.values() for e in edges)
    return sorted(labels)

class Path:
    """A partial hypothesis in the lattice search."""
    def __init__(self,cost=0.0,state=-1,path="",sequence=None):
        # accumulated cost of the hypothesis
        self.cost = cost
        # lattice state the hypothesis has reached
        self.state = state
        # text emitted so far
        self.path = path
        # edges taken so far; a fresh list per instance, because a ``[]``
        # default would be shared by every Path built without one
        self.sequence = [] if sequence is None else sequence

@method(Path)
def __repr__(self):
    """Render as ``<Path cost state 'path'>`` with cost to two decimals."""
    fields = (self.cost,self.state,self.path)
    return "<Path %.2f %d '%s'>"%fields

@method(Path)
def __str__(self):
    """Return exactly the text produced by __repr__."""
    text = self.__repr__()
    return text

@method(Path)
def __cmp__(self,other):
    """Order paths by cost, then state, then path text."""
    mine = (self.cost,self.state,self.path)
    theirs = (other.cost,other.state,other.path)
    return cmp(mine,theirs)

def expand(state,ngraphs,cweight=1.0,lweight=1.0,other=10.0,missing=30.0,verbose=0,rank=-1,thresh=1.0,lattice=None):
    """Extend a hypothesis along every lattice edge leaving its state.

    state -- the Path to extend
    ngraphs -- language model; its order ``N`` and its ``posteriors`` table
        are used.  ``posteriors`` maps a context string (the last N-1
        characters) to a dict of per-character costs, whose ``"~"`` entry
        is the floor for unseen characters.
    cweight -- weight applied to each edge's lattice cost
    lweight -- weight applied to the language model cost
    other -- unused; kept for backwards compatibility
    missing -- LM cost used when a context is absent from ``posteriors``
    verbose -- if true, print every edge considered
    rank -- unused; the hypothesis' rank within the beam
    thresh -- non-space edges cheaper than this get no language model cost
    lattice -- the Lattice to expand in; defaults to the module-level
        ``lattice`` global for backwards compatibility

    Returns a list of new Path objects, one per outgoing edge.
    """
    if lattice is None:
        lattice = globals()["lattice"]
    N = ngraphs.N
    # cost table used when the language model has never seen a context
    missing = {"~":missing}
    prefix = ng.lineproc(state.path)[-N+1:]
    posteriors = ngraphs.posteriors.get(prefix,missing)
    floor = posteriors["~"]
    result = []

    for e in lattice.edges[state.state]:
        if verbose: print(e)
        assert e.start==state.state
        # use a fixed-up class to match the fixed-up language model
        cls = ng.lineproc(e.cls)
        l = 0.0 if e.cost<thresh and e.cls!=" " else lweight
        if len(cls)==0 or e.cls=="~":
            # epsilon and wildcard edges only carry their lattice cost
            ncost = state.cost + cweight*e.cost
        elif len(cls)==1:
            ncost = state.cost + cweight*e.cost + l*posteriors.get(cls,floor)
        else:
            # multi-character class: score its last character in the context
            # of the N-1 characters preceding it (the old tpath[-N+1:]
            # context wrongly included the predicted character itself)
            tpath = prefix + cls
            lcost = ngraphs.posteriors.get(tpath[-N:-1],missing).get(tpath[-1],floor)
            ncost = state.cost + cweight*e.cost + l*lcost
        nstate = e.stop
        assert nstate>state.state,("oops: %s %s %s %s"%(e.start,e.stop,cls,e.cost))
        result.append(Path(cost=ncost,state=nstate,path=state.path+e.cls,sequence=state.sequence+[e]))

    return result

def eliminate_common_suffixes_and_sort(paths,n):
    """Return the best path for each distinct ``n``-character suffix, best first.

    ``paths`` are ordered by their natural ordering.  Because they are visited
    in sorted order, the first path seen for a suffix is its best one, and
    the kept paths are already in sorted order.
    """
    seen = set()
    result = []
    for p in sorted(paths):
        suffix = p.path[-n:]
        if suffix in seen: continue
        seen.add(suffix)
        result.append(p)
    return result

def search(lattice,ngraphs,accept=None,verbose=0,beam=100,**kw):
    """Beam-search ``lattice`` under the language model ``ngraphs``.

    States are processed in increasing order.  At each state, hypotheses are
    deduplicated on their last N characters, and the best ``beam`` of them
    are extended with expand().  The search stops at the first accepting
    state.  This is any state in ``accept`` if given, otherwise
    lattice.isAccept().

    verbose -- >0 prints the best hypothesis per state; >1 also every expansion
    kw -- passed on to expand()

    Returns the hypotheses at the stopping state, best first.  The list is
    empty if none reached it.
    """
    # the hypothesis table is left in a module global for debugging
    global table
    N = ngraphs.N
    if accept is None:
        is_accept = lattice.isAccept
    else:
        accept = set(accept)
        is_accept = accept.__contains__
    initial = Path(0.0,lattice.startState(),"_"*N)
    nstates = lattice.lastState()+1
    table = [[] for _ in range(nstates)]
    table[initial.state] = [initial]

    for i in range(nstates):
        if is_accept(i): break
        if not table[i]: continue
        table[i] = eliminate_common_suffixes_and_sort(table[i],n=N)
        if verbose>0: print("%s %s"%(i,table[i][0]))
        for rank,s in enumerate(table[i][:beam]):
            # pass the lattice explicitly rather than relying on a global
            for e in expand(s,ngraphs,rank=rank,lattice=lattice,**kw):
                if verbose>1: print("     %s"%(e,))
                table[e.state].append(e)

    return eliminate_common_suffixes_and_sort(table[i],n=N)

import argparse
import sys

parser = argparse.ArgumentParser()

parser.add_argument('--show',default=None,help="show the argument lattice")
parser.add_argument('--build',default=None,help="build and write a language model")
parser.add_argument('--ngraph',type=int,default=4,help="order of the language model")
parser.add_argument('--sample',default=None,type=int,help="sample from the language model")
parser.add_argument('--slength',default=70,type=int,help="length of the sampled strings")
parser.add_argument('-l','--lmodel',default=ocrolib.default.ngraphs,help="the language model")
parser.add_argument('-L','--lweight',default=0.5,type=float,help="language model weight")
parser.add_argument('-C','--cweight',default=1.0,type=float,help="character weight")
parser.add_argument('-B','--beam',default=10,type=int,help="beam width")
parser.add_argument('-W','--maxws',default=5.0,type=float,help="max whitespace cost")
parser.add_argument('-M','--maxcost',default=20.0,type=float,help="max cost")
parser.add_argument('-X','--mismatch',default=30.0,type=float,help="mismatch cost")
parser.add_argument('-T','--thresh',default=0.5,type=float,help="below this cost, ignore language model")

parser.add_argument('files',nargs='*')
args = parser.parse_args()
files = args.files

if args.show is not None:
    # display one lattice as a graph and exit
    lattice = Lattice(maxws=args.maxws,maxcost=args.maxcost,mismatch=args.mismatch)
    lattice.readLattice(args.show)
    lattice.showLattice()
    sys.exit(0)

if args.build is not None:
    # build an n-graph language model from text files and pickle it
    fnames = []
    for pattern in args.files:
        if "=" in pattern:
            # passed to buildFromFiles unexpanded
            fnames.append(pattern)
            continue
        l = glob.glob(pattern)
        if not l:
            parser.error("%s: didn't expand to any files"%pattern)
        for f in l:
            if ".lattice" in f or ".png" in f:
                parser.error("%s: expected a text file, not a lattice or image"%f)
        fnames += l
    print("got %d files"%len(fnames))
    ngraphs = ng.NGraphs()
    ngraphs.buildFromFiles(fnames,args.ngraph)
    # pickle protocol 2 is binary, so the file must be opened in binary mode
    with open(args.build,"wb") as stream:
        cPickle.dump(ngraphs,stream,2)
    sys.exit(0)

# only --sample and decoding need the language model, so it is resolved
# after --show and --build have had their chance to exit
args.lmodel = ocrolib.findfile(args.lmodel)
if not os.path.exists(args.lmodel):
    parser.error("%s: cannot find language model"%args.lmodel)

if args.sample is not None:
    # NOTE(review): unpickling can execute arbitrary code; only load
    # language models from trusted sources
    with open(args.lmodel,"rb") as stream:
        ngraphs = cPickle.load(stream)
    for _ in range(args.sample):
        print(ngraphs.sample(args.slength))
    sys.exit(0)

fnames = []
for pattern in args.files:
    l = sorted(glob.glob(pattern))
    for f in l:
        if ".lattice" not in f:
            parser.error("%s: all files must end with .lattice"%f)
    fnames += l

if not fnames:
    parser.error("must provide some filename arguments")

# NOTE(review): unpickling can execute arbitrary code; only load trusted models
with open(args.lmodel,"rb") as stream:
    ngraphs = cPickle.load(stream)

for fname in fnames:
    # kept at module level: expand() may read the global ``lattice``
    lattice = Lattice(maxws=args.maxws,maxcost=args.maxcost,mismatch=args.mismatch)
    lattice.readLattice(fname)
    result = search(lattice,ngraphs,lweight=args.lweight,cweight=args.cweight,beam=args.beam,thresh=args.thresh)
    if not result:
        sys.stderr.write("%s: no path reaches the final state; skipped\n"%fname)
        continue
    best = result[0]
    text = best.path[ngraphs.N:]
    print("%s\t%8.2f\t%s"%(fname,best.cost,text))
    fout = ocrolib.fvariant(fname,"txt")
    ocrolib.write_text(fout,text)
